{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c11fde21",
   "metadata": {},
   "source": [
    "# Multimodal video search using CLIP and LanceDB\r\n",
    "We used LanceDB to store one frame every thirty seconds, together with the title, for 13,000+ videos: 5 picked at random from each top category of the YouTube-8M dataset.\r\n",
    "Then, we used the CLIP model to embed frames and titles together. With LanceDB, we can perform embedding, keyword, and SQL search on these videos."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e647a892-de71-4842-a26d-035d628792d3",
   "metadata": {},
   "source": [
    "### Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69fb1627",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --quiet -U lancedb\n",
    "%pip install --quiet gradio transformers torch torchvision duckdb\n",
    "# NOTE(review): pip selects a VCS revision with '@<rev>'; the '#<sha>' fragment below probably does not pin the commit - confirm.\n",
    "%pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d53ade3",
   "metadata": {},
   "source": [
    "## First run setup: Download data and pre-process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b7e97f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import duckdb\n",
    "import lancedb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ba75742",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-runnable: resume/skip the download and unpack straight into the LanceDB directory\n",
    "# (the archive's top-level entry is multimodal_video.lance).\n",
    "!mkdir -p data/video-lancedb\n",
    "!wget -c https://vectordb-recipes.s3.us-west-2.amazonaws.com/multimodal_video_lance.tar.gz\n",
    "!tar -xf multimodal_video_lance.tar.gz -C data/video-lancedb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2fcbf61",
   "metadata": {},
   "source": [
    "## Create / Open LanceDB Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3317a3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "DB_URI = \"data/video-lancedb\"\n",
    "TABLE_NAME = \"multimodal_video\"\n",
    "\n",
    "# Connect to the local LanceDB directory and open the downloaded table.\n",
    "db = lancedb.connect(DB_URI)\n",
    "tbl = db.open_table(TABLE_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c6c7dfb-3504-4875-8f8b-a857bbb56727",
   "metadata": {},
   "source": [
    "## Create CLIP embedding function for the text\r\n",
    "![clip](https://miro.medium.com/v2/resize:fit:3662/1*tg7akErlMSyCLQxrMtQIYw.png)\r\n",
    "*CLIP model architecture.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8331d87",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast\n",
    "\n",
    "# Hugging Face checkpoint shared by the tokenizer, model and processor below.\n",
    "MODEL_ID = \"openai/clip-vit-base-patch32\"\n",
    "\n",
    "# The tokenizer and model are used by embed_func to embed text queries.\n",
    "tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)\n",
    "model = CLIPModel.from_pretrained(MODEL_ID)\n",
    "# NOTE(review): processor is not referenced by any code visible in this notebook.\n",
    "processor = CLIPProcessor.from_pretrained(MODEL_ID)\n",
    "\n",
    "\n",
    "def embed_func(query):\n",
    "    \"\"\"Return the CLIP text embedding of ``query`` as a 1-D NumPy array.\"\"\"\n",
    "    # truncation=True clips over-long queries to the tokenizer's maximum length;\n",
    "    # CLIP's text encoder has a fixed context length and fails on longer input.\n",
    "    inputs = tokenizer([query], padding=True, truncation=True, return_tensors=\"pt\")\n",
    "    text_features = model.get_text_features(**inputs)\n",
    "    # The batch holds a single query, so take row 0.\n",
    "    return text_features.detach().numpy()[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e26207b-20ce-41b6-8fe7-6050582881cc",
   "metadata": {},
   "source": [
    "## Search functions for Gradio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9f9e75a8-8e19-487c-8bab-bd2ffc9cbb06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_video_vectors(query):\n",
    "    \"\"\"Embed ``query`` with CLIP and run a vector search on ``tbl``.\n",
    "\n",
    "    Returns a ``(html, code)`` tuple: the iframe grid built by ``_extract``\n",
    "    from the first 9 results, and a Python snippet for the UI that shows\n",
    "    how to reproduce the search.\n",
    "    \"\"\"\n",
    "    emb = embed_func(query)\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"db = lancedb.connect('data/video-lancedb')\\n\"\n",
    "        \"tbl = db.open_table('multimodal_video')\\n\\n\"\n",
    "        # !r renders the query as a valid Python literal even when it\n",
    "        # contains quotes or backslashes.\n",
    "        f\"embedding = embed_func({query!r})\\n\"\n",
    "        \"tbl.search(embedding).limit(9).to_pandas()\"\n",
    "    )\n",
    "    return (_extract(tbl.search(emb).limit(9).to_pandas()), code)\n",
    "\n",
    "\n",
    "def find_video_keywords(query):\n",
    "    \"\"\"Search ``tbl`` with the raw text of ``query``.\n",
    "\n",
    "    Returns a ``(html, code)`` tuple: the iframe grid built by ``_extract``\n",
    "    from the first 9 results, and a Python snippet for the UI that shows\n",
    "    how to reproduce the search.\n",
    "    \"\"\"\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"db = lancedb.connect('data/video-lancedb')\\n\"\n",
    "        \"tbl = db.open_table('multimodal_video')\\n\\n\"\n",
    "        # !r renders the query as a valid Python literal even when it\n",
    "        # contains quotes or backslashes.\n",
    "        f\"tbl.search({query!r}).limit(9).to_pandas()\"\n",
    "    )\n",
    "    return (_extract(tbl.search(query).limit(9).to_pandas()), code)\n",
    "\n",
    "\n",
    "def find_video_sql(query):\n",
    "    \"\"\"Run the DuckDB SQL ``query`` against the table and return ``(html, code)``.\n",
    "\n",
    "    The table is exposed to SQL under the name ``videos``. The result is passed\n",
    "    to ``_extract``, which reads its ``video_id`` and ``start_time`` columns.\n",
    "    \"\"\"\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"import duckdb\\n\"\n",
    "        \"db = lancedb.connect('data/video-lancedb')\\n\"\n",
    "        \"tbl = db.open_table('multimodal_video')\\n\\n\"\n",
    "        \"videos = tbl.to_lance()\\n\"\n",
    "        # !r keeps the snippet valid Python when the SQL contains quotes;\n",
    "        # to_df() is the DuckDB relation method that returns a DataFrame.\n",
    "        f\"duckdb.sql({query!r}).to_df()\"\n",
    "    )\n",
    "    # Not unused: DuckDB resolves the SQL table name `videos` by looking up\n",
    "    # this local variable.\n",
    "    videos = tbl.to_lance()  # noqa: F841\n",
    "    return (_extract(duckdb.sql(query).to_df()), code)\n",
    "\n",
    "\n",
    "def _extract(df):\n",
    "    \"\"\"Render the rows of ``df`` as a 3-column HTML grid of YouTube embeds.\n",
    "\n",
    "    Reads the ``video_id`` and ``start_time`` columns of ``df``; each row\n",
    "    becomes one iframe whose URL carries ``start_time`` as the ``start``\n",
    "    parameter. Returns the grid as a single HTML string.\n",
    "    \"\"\"\n",
    "    from html import escape\n",
    "    from urllib.parse import quote\n",
    "\n",
    "    video_id_col = \"video_id\"\n",
    "    start_time_col = \"start_time\"\n",
    "    grid_open = '<div style=\"display: grid; grid-template-columns: repeat(3, 1fr); grid-gap: 20px;\">'\n",
    "\n",
    "    cells = []\n",
    "    # Iterate the two columns directly: iterrows() builds a Series per row and\n",
    "    # can upcast values (e.g. int -> float) when the frame mixes dtypes.\n",
    "    for video_id, start_time in zip(df[video_id_col], df[start_time_col]):\n",
    "        # The values may come from an arbitrary SQL query, so percent-encode\n",
    "        # them for the URL and escape the result for the HTML attribute.\n",
    "        src = escape(\n",
    "            \"https://www.youtube.com/embed/\"\n",
    "            f\"{quote(str(video_id), safe='')}?start={quote(str(start_time), safe='')}\"\n",
    "        )\n",
    "        iframe_code = (\n",
    "            f'<iframe width=\"100%\" height=\"315\" src=\"{src}\" '\n",
    "            'title=\"YouTube video player\" frameborder=\"0\" '\n",
    "            'allow=\"accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture\" '\n",
    "            \"allowfullscreen></iframe>\"\n",
    "        )\n",
    "        cells.append(f'<div style=\"width: 100%;\">{iframe_code}</div>')\n",
    "\n",
    "    return grid_open + \"\".join(cells) + \"</div>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10b8de6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_video_vectors(query):\n",
    "    \"\"\"Embed ``query`` with CLIP and run a vector search on ``tbl``.\n",
    "\n",
    "    Returns a ``(html, code)`` tuple: the iframe grid built by ``_extract``\n",
    "    from the first 9 results, and a Python snippet for the UI that shows\n",
    "    how to reproduce the search.\n",
    "    \"\"\"\n",
    "    emb = embed_func(query)\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"db = lancedb.connect('data/video-lancedb')\\n\"\n",
    "        \"tbl = db.open_table('multimodal_video')\\n\\n\"\n",
    "        # !r renders the query as a valid Python literal even when it\n",
    "        # contains quotes or backslashes.\n",
    "        f\"embedding = embed_func({query!r})\\n\"\n",
    "        \"tbl.search(embedding).limit(9).to_pandas()\"\n",
    "    )\n",
    "    return (_extract(tbl.search(emb).limit(9).to_pandas()), code)\n",
    "\n",
    "\n",
    "def find_video_keywords(query):\n",
    "    \"\"\"Search ``tbl`` with the raw text of ``query``.\n",
    "\n",
    "    Returns a ``(html, code)`` tuple: the iframe grid built by ``_extract``\n",
    "    from the first 9 results, and a Python snippet for the UI that shows\n",
    "    how to reproduce the search.\n",
    "    \"\"\"\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"db = lancedb.connect('data/video-lancedb')\\n\"\n",
    "        \"tbl = db.open_table('multimodal_video')\\n\\n\"\n",
    "        # !r renders the query as a valid Python literal even when it\n",
    "        # contains quotes or backslashes.\n",
    "        f\"tbl.search({query!r}).limit(9).to_pandas()\"\n",
    "    )\n",
    "    return (_extract(tbl.search(query).limit(9).to_pandas()), code)\n",
    "\n",
    "\n",
    "def find_video_sql(query):\n",
    "    \"\"\"Run the DuckDB SQL ``query`` against the table and return ``(html, code)``.\n",
    "\n",
    "    The table is exposed to SQL under the name ``videos``. The result is passed\n",
    "    to ``_extract``, which reads its ``video_id`` and ``start_time`` columns.\n",
    "    \"\"\"\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"import duckdb\\n\"\n",
    "        \"db = lancedb.connect('data/video-lancedb')\\n\"\n",
    "        \"tbl = db.open_table('multimodal_video')\\n\\n\"\n",
    "        \"videos = tbl.to_lance()\\n\"\n",
    "        # !r keeps the snippet valid Python when the SQL contains quotes;\n",
    "        # to_df() is the DuckDB relation method that returns a DataFrame.\n",
    "        f\"duckdb.sql({query!r}).to_df()\"\n",
    "    )\n",
    "    # Not unused: DuckDB resolves the SQL table name `videos` by looking up\n",
    "    # this local variable.\n",
    "    videos = tbl.to_lance()  # noqa: F841\n",
    "    return (_extract(duckdb.sql(query).to_df()), code)\n",
    "\n",
    "\n",
    "def _extract(df):\n",
    "    \"\"\"Render the rows of ``df`` as a 3-column HTML grid of YouTube embeds.\n",
    "\n",
    "    Reads the ``video_id`` and ``start_time`` columns of ``df``; each row\n",
    "    becomes one iframe whose URL carries ``start_time`` as the ``start``\n",
    "    parameter. Returns the grid as a single HTML string.\n",
    "    \"\"\"\n",
    "    from html import escape\n",
    "    from urllib.parse import quote\n",
    "\n",
    "    video_id_col = \"video_id\"\n",
    "    start_time_col = \"start_time\"\n",
    "    grid_open = '<div style=\"display: grid; grid-template-columns: repeat(3, 1fr); grid-gap: 20px;\">'\n",
    "\n",
    "    cells = []\n",
    "    # Iterate the two columns directly: iterrows() builds a Series per row and\n",
    "    # can upcast values (e.g. int -> float) when the frame mixes dtypes.\n",
    "    for video_id, start_time in zip(df[video_id_col], df[start_time_col]):\n",
    "        # The values may come from an arbitrary SQL query, so percent-encode\n",
    "        # them for the URL and escape the result for the HTML attribute.\n",
    "        src = escape(\n",
    "            \"https://www.youtube.com/embed/\"\n",
    "            f\"{quote(str(video_id), safe='')}?start={quote(str(start_time), safe='')}\"\n",
    "        )\n",
    "        iframe_code = (\n",
    "            f'<iframe width=\"100%\" height=\"315\" src=\"{src}\" '\n",
    "            'title=\"YouTube video player\" frameborder=\"0\" '\n",
    "            'allow=\"accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture\" '\n",
    "            \"allowfullscreen></iframe>\"\n",
    "        )\n",
    "        cells.append(f'<div style=\"width: 100%;\">{iframe_code}</div>')\n",
    "\n",
    "    return grid_open + \"\".join(cells) + \"</div>\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e3a8ac6-2266-4cf1-bce6-28c473328022",
   "metadata": {},
   "source": [
    "## Setup Gradio interface\n",
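    "The app has one tab per search mode (embeddings, keywords, SQL) and shows the matching videos together with the code that produced them."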
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f40300",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "\n",
    "# Intro text rendered at the top of the app.\n",
    "APP_DESCRIPTION = \"\"\"\n",
    "            # Multimodal video search using CLIP and LanceDB\n",
    "            We used LanceDB to store frames every thirty seconds and the title of 13000+ videos, 5 random from each top category from the Youtube 8M dataset. \n",
    "            Then, we used the CLIP model to embed frames and titles together. With LanceDB, we can perform embedding, keyword, and SQL search on these videos.\n",
    "            \"\"\"\n",
    "# Query pre-filled in the SQL tab.\n",
    "DEFAULT_SQL = \"SELECT DISTINCT video_id, * from videos WHERE start_time > 0 LIMIT 9\"\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(APP_DESCRIPTION)\n",
    "\n",
    "    # One tab per search mode, each with its own query box and submit button.\n",
    "    with gr.Row():\n",
    "        with gr.Tab(\"Embeddings\"):\n",
    "            vector_query = gr.Textbox(value=\"retro gaming\", show_label=False)\n",
    "            b1 = gr.Button(\"Submit\")\n",
    "        with gr.Tab(\"Keywords\"):\n",
    "            keyword_query = gr.Textbox(value=\"ninja turtles\", show_label=False)\n",
    "            b2 = gr.Button(\"Submit\")\n",
    "        with gr.Tab(\"SQL\"):\n",
    "            sql_query = gr.Textbox(value=DEFAULT_SQL, show_label=False)\n",
    "            b3 = gr.Button(\"Submit\")\n",
    "\n",
    "    # Shared output area: the reproducing code snippet, then the video grid.\n",
    "    with gr.Row():\n",
    "        code = gr.Code(label=\"Code\", language=\"python\")\n",
    "    with gr.Row():\n",
    "        gallery = gr.HTML()\n",
    "\n",
    "    # Every handler returns (html, code), matching this output order.\n",
    "    search_outputs = [gallery, code]\n",
    "    for button, handler, textbox in (\n",
    "        (b1, find_video_vectors, vector_query),\n",
    "        (b2, find_video_keywords, keyword_query),\n",
    "        (b3, find_video_sql, sql_query),\n",
    "    ):\n",
    "        button.click(handler, inputs=textbox, outputs=search_outputs)\n",
    "\n",
    "demo.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  },
  "vscode": {
   "interpreter": {
    "hash": "511a7c77cb034b09af5465c01316a0f4bb20176d139e60e6d7915f9a637a5037"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
